{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b8805d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "sns.set();\n",
    "\n",
    "from dpipe.io import load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a7555b27",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_stats_spottune(exp_path, fold='inference'):\n",
    "    p = torch.load(Path(exp_path) / f'policy_{fold}_record/policy_record')\n",
    "    f = open(Path(exp_path) / f'policy_{fold}_record/iter_record', 'r')\n",
    "    n_iter = f.read()\n",
    "    f.close()\n",
    "    record = (p / int(n_iter)).detach().numpy()\n",
    "    return record"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "507f8815",
   "metadata": {},
   "outputs": [],
   "source": [
    "ns = np.arange(25)\n",
    "\n",
    "# '8' has some problems with NaN policy...\n",
    "slices2folder = {'5': '1_48', '12': '1_24', '24': '1_12', '45': '1_6',\n",
    "                 '90': '1_3', '270': '1', '540': '2', '800': '3'}\n",
    "path_template = '/gpfs/data/gpfs0/b.shirokikh/experiments/da/miccai2021_spottune/finetune/{}/spottune_new/'\n",
    "\n",
    "records = defaultdict(list)\n",
    "for k, v in slices2folder.items():\n",
    "    for n in ns:\n",
    "        records[k].append(get_stats_spottune(Path(path_template.format(v)) / f'experiment_{n}'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2067eab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "records_mean = {k: 1 - np.mean(v, axis=0) for k, v in records.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b158fe42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: switch to new (~lighter) palette\n",
    "# sns.cubehelix_palette(start=2, rot=0, dark=0, light=.95, reverse=True)\n",
    "# sns.color_palette(\"dark:salmon_r\")  # remove _r\n",
    "# "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d70a04fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def floats2colors(floats, bins=256):\n",
    "    palette = sns.color_palette(palette='magma', n_colors=bins).as_hex()\n",
    "    ints = np.int64(floats * (bins - 1))\n",
    "    colors = [palette[i] for i in ints]\n",
    "    return colors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f724973",
   "metadata": {},
   "source": [
    "### colors per blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8a4eb0ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-1 5          12         24         45         90         270        540        800       \n",
      " 1 #000004    #010005    #010005    #050416    #000004    #050416    #000004    #671b80   \n",
      " 2 #fcfdbf    #fecd90    #fde7a9    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf   \n",
      " 3 #000004    #010005    #621980    #221150    #06051a    #02020b    #120d31    #f3655c   \n",
      " 4 #000004    #010005    #241253    #1c1044    #120d31    #19103f    #271258    #fd9467   \n",
      " 5 #010106    #000004    #050416    #02020d    #000004    #020109    #000004    #feae77   \n",
      " 6 #fcecae    #febb81    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf   \n",
      " 7 #000004    #000004    #19103f    #040414    #000004    #020109    #150e38    #df4a68   \n",
      " 8 #010005    #000004    #a8327d    #902a81    #221150    #6a1c81    #5f187f    #fcf0b2   \n",
      " 9 #000004    #000004    #160f3b    #150e38    #020109    #110c2f    #0b0924    #fdda9c   \n",
      "10 #fb8761    #ee5b5e    #fcfbbd    #fcfbbd    #fec488    #fcecae    #feb078    #fcfbbd   \n",
      "11 #02020d    #02020b    #d8456c    #932b80    #2f1163    #792282    #57157e    #fed799   \n",
      "12 #000004    #000004    #331067    #07061c    #000004    #1d1147    #19103f    #fcfdbf   \n",
      "13 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #f66c5c   \n",
      "14 #992d80    #782281    #fde2a3    #fddea0    #fa7f5e    #fed194    #febf84    #fcfdbf   \n",
      "15 #000004    #000004    #03030f    #030312    #000004    #000004    #010106    #fde9aa   \n",
      "16 #000004    #000004    #000004    #000004    #000004    #000004    #06051a    #fed799   \n",
      "17 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #000004   \n",
      "18 #000004    #000004    #221150    #2c115f    #000004    #110c2f    #1e1149    #fde2a3   \n",
      "19 #000004    #000004    #000004    #000004    #000004    #000004    #030312    #fcfdbf   \n",
      "20 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #e55064   \n",
      "21 #000004    #000004    #08071e    #0a0822    #010005    #050416    #140e36    #fcfdbf   \n",
      "22 #000004    #000004    #000004    #000004    #000004    #000004    #010005    #c83e73   \n",
      "23 #000004    #000004    #02020d    #000004    #000004    #000004    #07061c    #fcfbbd   \n",
      "24 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #feca8d   \n",
      "25 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #5a167e   \n",
      "26 #000004    #000004    #000004    #040414    #000004    #040414    #140e36    #fde0a1   \n",
      "27 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #e85362   \n",
      "28 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #febb81   \n",
      "29 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #21114e   \n",
      "30 #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf    #fcfdbf   \n",
      "31 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #010005   \n",
      "32 #000004    #000004    #000004    #000004    #000004    #000004    #000004    #000004   \n"
     ]
    }
   ],
   "source": [
    "colors = {k: floats2colors(v) for k, v in records_mean.items()}\n",
    "print(-1, *[f'{k:10s}' for k in colors.keys()])\n",
    "for i in range(1, len(colors['5']) + 1):\n",
    "    print(f'{i:2d}', *[f'{v[i - 1]:10s}' for _, v in colors.items()])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788da9cf",
   "metadata": {},
   "source": [
    "### colorbar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f339bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = floats2colors(np.linspace(0, 1, num=11))\n",
    "for i, c in enumerate(colors, start=1):\n",
    "    print(i, c)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
